from itertools import product
from functools import reduce
from operator import add
from collections import deque
from typing import NamedTuple
from dataclasses import dataclass

import numpy as np

from qiskit import QuantumCircuit, QuantumRegister
from qiskit.circuit.library import EfficientSU2
from qiskit.primitives import Estimator

from qiskit.quantum_info import state_fidelity, SparsePauliOp, Statevector
from qiskit.circuit import Parameter

from ..utils import expand_params
from ..energy import EnergyBackend, register_backend


# A single Hamiltonian summand: a Pauli label (one character per qubit) and its real coefficient.
# Instances are plain 2-tuples, so a list of them can be fed to ``SparsePauliOp.from_list`` directly.
PauliOp = NamedTuple('PauliOp', [('string', str), ('coeff', float)])


def generic_circuit(n_qbits, n_layers):
    """Construct a generic quantum circuit

    Layout: one ``rx``/``rz`` rotation on every qubit, a barrier, then ``n_layers`` repetitions of
    (``rxx`` between every qubit ``i`` and ``(i + 1) % n_qbits``, a barrier, ``rx``/``rz``/``ry`` on every qubit).
    Each gate gets its own free parameter named ``θ[row, qubit]``, where row 0/1 are the initial rotations and each
    layer uses four consecutive rows (entangler, rx, rz, ry).

    Parameters
    ----------
    n_qbits: Int
        Total number of qbits in the system
    n_layers: Int
        Total number of (entangling and rotation) layers in the ansatz

    Returns
    -------
    circuit: :class:`qiskit.QuantumCircuit`
        The ansatz with all parameters unassigned.
    """
    reg = QuantumRegister(n_qbits)
    circuit = QuantumCircuit(reg)
    n_rows = 2 + 4 * n_layers
    theta = [
        [Parameter(f'θ[{row}, {qbit}]') for qbit in range(n_qbits)]
        for row in range(n_rows)
    ]

    # Initial single-qubit rotation block.
    for qbit in range(n_qbits):
        circuit.rx(theta[0][qbit], reg[qbit])
        circuit.rz(theta[1][qbit], reg[qbit])
    circuit.barrier()

    # Each layer consumes four parameter rows starting at row 2.
    for first_row in range(2, n_rows, 4):
        ent_row, rx_row, rz_row, ry_row = theta[first_row:first_row + 4]
        for qbit in range(n_qbits):
            circuit.rxx(ent_row[qbit], reg[qbit], reg[(qbit + 1) % n_qbits])
        circuit.barrier()
        for qbit in range(n_qbits):
            circuit.rx(rx_row[qbit], reg[qbit])
            circuit.rz(rz_row[qbit], reg[qbit])
            circuit.ry(ry_row[qbit], reg[qbit])
    return circuit


@register_backend('qiskit')
@dataclass
class QiskitBackend(EnergyBackend):
    """Qiskit-based backend for VQE.

    The parametrised circuit and the Hamiltonian are built once in ``__post_init__`` and cached as
    ``self.qcircuit`` and ``self.hamiltonian``. The configuration attributes read by this class (``n_qbits``,
    ``n_layers``, ``h``, ``j``, ``pbc``, ``mom_sector``, ``circuit``, ``rng``) are not declared here; presumably
    they are dataclass fields of :class:`EnergyBackend` — confirm there.
    """
    name = 'qiskit'

    def __post_init__(self):
        """Build and cache the ansatz circuit and the Hamiltonian."""
        self.qcircuit = self.make_circuit()
        self.hamiltonian = self.heisenberg_hamiltonian()

    def heisenberg_hamiltonian_terms(self):
        """Constructs the summands of the Heisenberg Hamiltonian with arbitrary external fields.

        For each Pauli axis ``op`` in X, Y, Z (zipped with the matching entries of ``self.h`` and ``self.j``) this
        emits a single-site term ``op`` with coefficient ``h`` on every site, and a nearest-neighbour term
        ``op op`` with coefficient ``j`` on every bond. A ring (``self.pbc`` true) has ``n_qbits`` bonds. An open
        chain has ``n_qbits - 1`` bonds, because the wrap-around bond is omitted. A single site has no bonds.
        Terms whose coefficient is exactly zero are dropped.

        Returns
        -------
        summands: list of :class:`PauliOp`
            ``(pauli_string, coefficient)`` pairs; every string has length ``self.n_qbits``.
        """
        n_qbits = self.n_qbits
        # With one site, 'op op' would be a 2-character label for a 1-qubit operator, so emit no bonds at all.
        if n_qbits < 2:
            n_bonds = 0
        else:
            n_bonds = n_qbits if self.pbc else n_qbits - 1

        summands = []
        for op, s_coeff, p_coeff in zip(('X', 'Y', 'Z'), self.h, self.j):
            single = deque(op + 'I' * (n_qbits - 1))
            paired = deque(op + op + 'I' * max(n_qbits - 2, 0))
            for site in range(n_qbits):
                summands.append(PauliOp(''.join(single), s_coeff))
                # The final rotation of `paired` is the wrap-around bond, which is the one skipped for open chains.
                if site < n_bonds:
                    summands.append(PauliOp(''.join(paired), p_coeff))
                # Cyclically shift both labels one position to the right for the next site.
                single.rotate()
                paired.rotate()

        return [elem for elem in summands if elem.coeff != 0.0]

    def heisenberg_hamiltonian(self):
        """Constructs the Heisenberg Hamiltonian with arbitrary external fields.

        Returns
        -------
        hamiltonian: :class:`qiskit.quantum_info.SparsePauliOp`
            The sum of the terms from :meth:`heisenberg_hamiltonian_terms`. If every coefficient is zero, this is
            the zero operator on ``self.n_qbits`` qubits.
        """
        terms = self.heisenberg_hamiltonian_terms()
        if not terms:
            # `from_list` cannot infer the number of qubits from an empty list; build the zero operator directly.
            return SparsePauliOp('I' * self.n_qbits, coeffs=[0.0])
        return SparsePauliOp.from_list(terms)

    def assign_angles(self, angles):
        """Returns the circuit with gate parameters assigned `angles` (flattened, in the circuit's parameter order)."""
        return self.qcircuit.assign_parameters(angles.flatten())

    def make_circuit(self):
        """Creates the quantum circuit with unassigned parameters.

        The circuit is an initial-state preparation selected by ``self.mom_sector``, followed by the variational
        ansatz selected by ``self.circuit``.

        Raises
        ------
        RuntimeError
            If ``self.mom_sector`` is neither ``1`` nor ``-1``, or ``self.circuit`` is neither ``'generic'`` nor
            ``'esu2'``.
        """
        reg = QuantumRegister(self.n_qbits, 'q')
        ansatz = QuantumCircuit(reg)
        if self.mom_sector == -1:
            # Flip every other qubit with a single Pauli-X string.
            # NOTE(review): Qiskit Pauli labels are little-endian (rightmost character acts on qubit 0), so this
            # flips qubits 0, 2, ... for even n_qbits but 1, 3, ... for odd n_qbits; an earlier version flipped
            # qubits 1, 3, ... in both cases — confirm which placement is intended.
            ansatz.pauli(('IX' * ((self.n_qbits + 1) // 2))[:self.n_qbits], reg)
        elif self.mom_sector != 1:
            raise RuntimeError(f'Unsupported sector: \'{self.mom_sector}\'')

        if self.circuit == 'generic':
            qcircuit = generic_circuit(self.n_qbits, self.n_layers)
        elif self.circuit == 'esu2':
            qcircuit = EfficientSU2(
                num_qubits=self.n_qbits,
                su2_gates=None,
                entanglement='full',
                insert_barriers=True,
                reps=self.n_layers,
                parameter_prefix='θ'
            )
        else:
            raise RuntimeError(f'Qiskit only supports circuits \'generic\' and \'esu2\', got \'{self.circuit}\'')

        return ansatz.compose(qcircuit)

    def measure_energy(self, angles, n_readout):
        """Measure the expected energy and variance of the hamiltonian.

        Parameters
        ----------
        angles: :py:obj:`np.ndarray`
            The gate parameters; one energy is measured per batch entry (first axis after `expand_params`).
        n_readout: int or None
            Number of shots (readouts) used to measure the hamiltonian. ``None`` or a value ``<= 0`` requests the
            exact expectation value, in which case the reported variances are zero.

        Returns
        -------
        energies: :py:obj:`np.ndarray`
            The mean energy values over the shots.
        variances: :py:obj:`np.ndarray`
            The variance of the mean energy values over the shots.

        """
        # TODO: Needs to implement the support of noise level as in compute energy.
        angles = expand_params(np.array(angles), self.n_qbits)
        # One seed per call, shared by every batch entry, drawn from the backend's rng for reproducibility.
        seed_simulator = int.from_bytes(self.rng.bytes(4), 'big')

        # The Estimator expresses "exact" as shots=None.
        shots = n_readout if n_readout is not None and n_readout > 0 else None
        estimator = Estimator()

        energies, variances = [], []
        for batch in angles:
            result = estimator.run(
                self.assign_angles(batch), self.hamiltonian, shots=shots, seed=seed_simulator
            ).result()
            energies.append(result.values[0])
            # The Estimator reports the per-shot variance; dividing by the shot count gives that of the mean.
            variances.append(0.0 if shots is None else result.metadata[0]['variance'] / shots)

        return np.array(energies), np.array(variances)

    def measure_overlap(self, angles, exact_wf):
        """Computes the fidelity between a reference state vector and the states prepared by the circuit.

        Parameters
        ----------
        angles: :py:obj:`np.ndarray`
            The gate parameters; one state is prepared per batch entry (first axis after `expand_params`).
        exact_wf: obj:`numpy.ndarray`
            Reference state vector as a complex numpy array.

        Returns
        -------
        fidelities: :py:obj:`np.ndarray`
            ``qiskit.quantum_info.state_fidelity`` between ``exact_wf`` and each prepared state, i.e.
            ``|<exact_wf|psi(angles)>|^2`` for pure states. The array is empty if there are no batch entries.
        """
        angles = expand_params(np.array(angles), self.n_qbits)
        return np.array(
            [state_fidelity(exact_wf, Statevector(self.assign_angles(batch))) for batch in angles],
            dtype=float,
        )

    def parameter_shift_gradient(self, angles, n_readout):
        """Compute the gradient of the mean energy wrt. the gate parameters `angles` using the parameter shift rule.

        Uses the two-term rule ``dE/dθ = [E(θ + π/2) - E(θ - π/2)] / 2``. This is exact for rotation gates of the
        form ``exp(-i θ P / 2)`` with ``P`` a Pauli string, provided each parameter appears in exactly one gate.

        Parameters
        ----------
        angles: :py:obj:`np.ndarray`
            The gate parameters.
        n_readout: int
            Number of shots (readouts) used to measure the hamiltonian.

        Returns
        -------
        grad: :py:obj:`np.ndarray`
            The gradient of the mean energy values over the shots wrt. the gate parameters `angles`, with the same
            shape as the expanded `angles`.

        Note
        ----
        A future release may compute the variance also, in case it is needed.
        The variance of the parameter shift requires covariances of the individual shots, which is not possible to get
        through the current way of the energy estimation.

        """
        # Private float copy: the shifts below mutate it in place, and an integer array could not hold them.
        angles = np.array(expand_params(np.array(angles), self.n_qbits), dtype=float)

        def measure(shifted):
            # Each batch entry is measured in its own call, so each draws a fresh seed from `self.rng`.
            return np.concatenate([self.measure_energy(batch[None], n_readout)[0] for batch in shifted], axis=0)

        grad = np.zeros_like(angles)
        # Visit every parameter position; the leading (batch) axis is kept whole so all batches shift together.
        for index in product((slice(None),), *(range(s) for s in angles.shape[1:])):
            original = angles[index].copy()
            angles[index] = original + np.pi / 2.
            grad[index] = measure(angles) / 2.
            angles[index] = original - np.pi / 2.
            grad[index] -= measure(angles) / 2.
            angles[index] = original

        return grad
